import torch
import torch.nn as nn
import torch.nn.functional as F


class lossAV(nn.Module):
    """Two-class linear head with a masked cross-entropy loss.

    A linear layer maps every input feature vector to two logits. Without
    labels the module returns the class-1 logits as a NumPy array; with labels
    it returns the masked mean loss together with the softmax scores, the
    predicted classes and the count of correct, unmasked predictions.
    """

    def __init__(self, in_features=128):
        """Build the per-sample criterion and the linear classifier.

        Args:
            in_features: Size of each input feature vector. Defaults to 128.
        """
        super(lossAV, self).__init__()
        # reduction='none' keeps one loss per sample so masking can be applied
        # before averaging.
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels=None, masks=None):
        """Score the inputs and, when labels are given, compute the loss.

        Args:
            x: Input features whose last dimension matches the linear layer's
                input size. Dimension 1 is squeezed away when it has size 1.
            labels: Targets handed to ``nn.CrossEntropyLoss``. When ``None``
                the method runs in inference mode and only returns scores.
            masks: Optional per-sample weights multiplied into the per-sample
                loss and the correctness indicator. ``None`` means every
                sample counts (equivalent to a mask of ones).

        Returns:
            If ``labels`` is ``None``: a flat NumPy array holding the class-1
            logit of every sample (raw logits, not softmax probabilities).
            Otherwise a tuple ``(nloss, predScore, predLabel, correctNum)``:
            the masked mean loss, the softmax scores, the argmax class per
            sample, and the masked number of correct predictions.
        """
        x = x.squeeze(1)
        x = self.FC(x)
        if labels is None:
            predScore = x[:, 1]
            # For 2-D logits this slice is 1-D, so the transpose is a no-op.
            predScore = predScore.t()
            predScore = predScore.view(-1).detach().cpu().numpy()
            return predScore

        per_sample = self.criterion(x, labels)
        if masks is None:
            # No mask supplied: weight every sample equally instead of
            # failing on `tensor * None`.
            masks = torch.ones_like(per_sample)

        num_valid = masks.sum().float()
        # The epsilon avoids division by zero when every sample is masked out.
        nloss = torch.sum(per_sample * masks) / (num_valid + 1e-6)

        predScore = F.softmax(x, dim=-1)
        predLabel = torch.argmax(predScore, dim=-1)  # predicted class
        correctNum = ((predLabel == labels) * masks).sum().float()
        return nloss, predScore, predLabel, correctNum


class lossA(nn.Module):
    """Two-class linear head that returns a masked mean cross-entropy loss."""

    def __init__(self, in_features=128):
        """Build the per-sample criterion and the linear classifier.

        Args:
            in_features: Size of each input feature vector. Defaults to 128.
        """
        super(lossA, self).__init__()
        # reduction='none' keeps one loss per sample so masking can be applied
        # before averaging.
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels, masks=None):
        """Compute the masked mean cross-entropy loss.

        Args:
            x: Input features whose last dimension matches the linear layer's
                input size. Dimension 1 is squeezed away when it has size 1.
            labels: Targets handed to ``nn.CrossEntropyLoss``.
            masks: Optional per-sample weights multiplied into the per-sample
                loss. ``None`` means every sample counts (equivalent to a
                mask of ones).

        Returns:
            The sum of masked per-sample losses divided by the mask sum
            (plus a small epsilon), as a scalar tensor.
        """
        x = x.squeeze(1)
        x = self.FC(x)
        per_sample = self.criterion(x, labels)
        if masks is None:
            # No mask supplied: weight every sample equally instead of
            # failing on `tensor * None`.
            masks = torch.ones_like(per_sample)
        num_valid = masks.sum().float()
        # The epsilon avoids division by zero when every sample is masked out.
        return torch.sum(per_sample * masks) / (num_valid + 1e-6)


class lossV(nn.Module):
    """Two-class linear head that returns a masked mean cross-entropy loss."""

    def __init__(self, in_features=128):
        """Build the per-sample criterion and the linear classifier.

        Args:
            in_features: Size of each input feature vector. Defaults to 128.
        """
        super(lossV, self).__init__()
        # reduction='none' keeps one loss per sample so masking can be applied
        # before averaging.
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels, masks=None):
        """Compute the masked mean cross-entropy loss.

        Args:
            x: Input features whose last dimension matches the linear layer's
                input size. Dimension 1 is squeezed away when it has size 1.
            labels: Targets handed to ``nn.CrossEntropyLoss``.
            masks: Optional per-sample weights multiplied into the per-sample
                loss. ``None`` means every sample counts (equivalent to a
                mask of ones).

        Returns:
            The sum of masked per-sample losses divided by the mask sum
            (plus a small epsilon), as a scalar tensor.
        """
        x = x.squeeze(1)
        x = self.FC(x)
        per_sample = self.criterion(x, labels)
        if masks is None:
            # No mask supplied: weight every sample equally instead of
            # failing on `tensor * None`.
            masks = torch.ones_like(per_sample)
        num_valid = masks.sum().float()
        # The epsilon avoids division by zero when every sample is masked out.
        return torch.sum(per_sample * masks) / (num_valid + 1e-6)




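# NOTE: The commented-out code below is an alternative implementation of the
# loss classes above. It differs in that it calls
# `du.all_reduce([num_valid], average=True)` on the valid-sample count while
# training, divides by that count without the 1e-6 epsilon, builds lossAV's
# linear layer with 256 input features, and derives predLabel by rounding the
# softmax scores.
# import torch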
# import torch.nn as nn
# import torch.nn.functional as F
# import utils.distributed as du


# class lossAV(nn.Module):

#     def __init__(self):
#         super(lossAV, self).__init__()
#         self.criterion = nn.CrossEntropyLoss(reduction='none')
#         self.FC = nn.Linear(256, 2)

#     def forward(self, x, labels=None, masks=None):
#         x = x.squeeze(1)
#         x = self.FC(x)
#         if labels == None:
#             predScore = x[:, 1]
#             predScore = predScore.t()
#             predScore = predScore.view(-1).detach().cpu().numpy()
#             return predScore
#         else:
#             nloss = self.criterion(x, labels) * masks

#             num_valid = masks.sum().float()
#             if self.training:
#                 [num_valid] = du.all_reduce([num_valid],average=True)
#             nloss = torch.sum(nloss) / num_valid

#             predScore = F.softmax(x, dim=-1)
#             predLabel = torch.round(F.softmax(x, dim=-1))[:, 1]
#             correctNum = ((predLabel == labels) * masks).sum().float()
#             return nloss, predScore, predLabel, correctNum


# class lossA(nn.Module):

#     def __init__(self):
#         super(lossA, self).__init__()
#         self.criterion = nn.CrossEntropyLoss(reduction='none')
#         self.FC = nn.Linear(128, 2)

#     def forward(self, x, labels, masks=None):
#         x = x.squeeze(1)
#         x = self.FC(x)
#         nloss = self.criterion(x, labels) * masks
#         num_valid = masks.sum().float()
#         if self.training:
#             [num_valid] = du.all_reduce([num_valid],average=True)
#         nloss = torch.sum(nloss) / num_valid
#         #nloss = torch.sum(nloss) / torch.sum(masks)
#         return nloss


# class lossV(nn.Module):

#     def __init__(self):
#         super(lossV, self).__init__()

#         self.criterion = nn.CrossEntropyLoss(reduction='none')
#         self.FC = nn.Linear(128, 2)

#     def forward(self, x, labels, masks=None):
#         x = x.squeeze(1)
#         x = self.FC(x)
#         nloss = self.criterion(x, labels) * masks
#         # nloss = torch.sum(nloss) / torch.sum(masks)
#         num_valid = masks.sum().float()
#         if self.training:
#             [num_valid] = du.all_reduce([num_valid],average=True)
#         nloss = torch.sum(nloss) / num_valid
#         return nloss
